# Copyright (c) OpenMMLab. All rights reserved.
import builtins
import functools
import io
import os
from contextlib import contextmanager

import mmengine.fileio as fileio
from mmengine.fileio import LocalBackend, PetrelBackend, get_file_backend


def patch_func(module, fn_name_to_wrap):
    """Build a decorator that installs a replacement for ``module.<name>``.

    The current attribute is looked up immediately, so a missing name raises
    ``AttributeError`` here rather than when the decorator is applied. When
    the decorator is applied to a replacement function it:

    * sets ``module.<fn_name_to_wrap>`` to the replacement;
    * appends ``(module, fn_name_to_wrap, original)`` to the list kept as
      ``patch_func._backup`` so the caller can undo the patch later;
    * exposes the original as ``replacement._fallback``.

    Args:
        module: Module or class whose attribute is replaced.
        fn_name_to_wrap: Name of the attribute to replace.

    Returns:
        The decorator, which returns the replacement function unchanged.
    """
    registry = getattr(patch_func, "_backup", [])
    original = getattr(module, fn_name_to_wrap)

    def install(replacement):
        setattr(module, fn_name_to_wrap, replacement)
        registry.append((module, fn_name_to_wrap, original))
        replacement._fallback = original
        patch_func._backup = registry
        return replacement

    return install


@contextmanager
def patch_fileio(global_vars=None):
    if getattr(patch_fileio, "_patched", False):
        # Only patch once, avoid error caused by patch nestly.
        yield
        return
    import builtins

    @patch_func(builtins, "open")
    def open(file, mode="r", *args, **kwargs):
        backend = get_file_backend(file)
        if isinstance(backend, LocalBackend):
            return open._fallback(file, mode, *args, **kwargs)
        if "b" in mode:
            return io.BytesIO(backend.get(file, *args, **kwargs))
        else:
            return io.StringIO(backend.get_text(file, *args, **kwargs))

    if global_vars is not None and "open" in global_vars:
        bak_open = global_vars["open"]
        global_vars["open"] = builtins.open

    import os

    @patch_func(os.path, "join")
    def join(a, *paths):
        backend = get_file_backend(a.decode("utf-8") if isinstance(a, bytes) else a)
        if isinstance(backend, LocalBackend):
            return join._fallback(a, *paths)
        paths = [item.lstrip("./") for item in paths if len(item) > 0]
        return backend.join_path(a, *paths)

    @patch_func(os.path, "isdir")
    def isdir(path):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return isdir._fallback(path)

        return backend.isdir(path)

    @patch_func(os.path, "isfile")
    def isfile(path):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return isfile._fallback(path)

        return backend.isfile(path)

    @patch_func(os.path, "exists")
    def exists(path):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return exists._fallback(path)
        return backend.exists(path)

    @patch_func(os, "mkdir")
    def mkdir(path, *args, **kwargs):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return mkdir._fallback(path, *args, **kwargs)

    @patch_func(os, "makedirs")
    def makedirs(path, *args, **kwargs):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return makedirs._fallback(path, *args, **kwargs)

    @patch_func(os, "listdir")
    def listdir(path):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return listdir._fallback(path)
        return backend.list_dir_or_file(path)

    @patch_func(os, "chmod")
    def chmod(path, *args, **kwargs):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return chmod._fallback(path, *args, **kwargs)

    @patch_func(os, "stat")
    def stat(path, *args, **kwargs):
        backend = get_file_backend(path)
        if isinstance(backend, LocalBackend):
            return stat._fallback(path, *args, **kwargs)

    import glob as glob_pkg

    @patch_func(glob_pkg, "glob")
    def glob(pathname, *, recursive=False):
        backend = get_file_backend(pathname)
        if isinstance(backend, LocalBackend):
            return glob._fallback(pathname, recursive=recursive)

        if pathname.endswith("*_optim_states.pt"):
            import os

            pathname = os.path.split(pathname)[0]
            files = backend.list_dir_or_file(pathname, recursive=recursive)
            files = [
                os.path.join(pathname, f)
                for f in files
                if f.endswith("_optim_states.pt")
            ]
        elif pathname.endswith("*_model_states.pt"):
            import os

            pathname = os.path.split(pathname)[0]
            files = backend.list_dir_or_file(pathname, recursive=recursive)
            files = [
                os.path.join(pathname, f)
                for f in files
                if f.endswith("_model_states.pt")
            ]
        elif "*" in pathname:
            raise NotImplementedError
        else:
            files = backend.list_dir_or_file(pathname, recursive=recursive)

        return files

    import filecmp

    @patch_func(filecmp, "cmp")
    def cmp(f1, f2, *args, **kwargs):
        with fileio.get_local_path(f1) as f1, fileio.get_local_path(f2) as f2:
            return cmp._fallback(f1, f2, *args, **kwargs)

    import shutil

    @patch_func(shutil, "copy")
    def copy(src, dst, **kwargs):
        from pathlib import Path

        if isinstance(src, Path):
            src = str(src).replace(":/", "://")
        if isinstance(dst, Path):
            dst = str(dst).replace(":/", "://")

        src_backend = get_file_backend(src)
        dst_backend = get_file_backend(dst)

        if isinstance(src_backend, LocalBackend) and isinstance(
            dst_backend, LocalBackend
        ):
            return copy._fallback(src, dst, **kwargs)
        elif isinstance(src_backend, LocalBackend) and isinstance(
            dst_backend, PetrelBackend
        ):
            return dst_backend.copyfile_from_local(str(src), str(dst))
        elif isinstance(src_backend, PetrelBackend) and isinstance(
            dst_backend, LocalBackend
        ):
            return src_backend.copyfile_to_local(str(src), str(dst))

    import torch

    @patch_func(torch, "load")
    def load(f, *args, **kwargs):
        if isinstance(f, str):
            f = io.BytesIO(fileio.get(f))
        return load._fallback(f, *args, **kwargs)

    @patch_func(torch, "save")
    def save(obj, f, *args, **kwargs):
        backend = get_file_backend(f)
        if isinstance(backend, LocalBackend):
            return save._fallback(obj, f, *args, **kwargs)

        with io.BytesIO() as buffer:
            save._fallback(obj, buffer, *args, **kwargs)
            buffer.seek(0)
            backend.put(buffer, f)

        # from tempfile import TemporaryDirectory
        # import os
        # with TemporaryDirectory(dir='/dev/shm') as tmpdir:
        #     suffix = os.path.split(f)[-1]
        #     tmppath = os.path.join._fallback(tmpdir, suffix)
        #     from mmengine import print_log
        #     print_log('write to tmp dir', logger='current')
        #     save._fallback(obj, tmppath, *args, **kwargs)
        #     print_log('write to ceph', logger='current')

        #     with open(tmppath, 'rb') as buffer:
        #         backend.put(buffer, f)

    from sentencepiece import SentencePieceProcessor

    @patch_func(SentencePieceProcessor, "LoadFromFile")
    def LoadFromFile(cls, path):
        if path:
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return LoadFromFile._fallback(cls, path)
            from tempfile import TemporaryDirectory

            with TemporaryDirectory() as tmpdir:
                local_path = backend.copyfile_to_local(path, tmpdir)
                loaded_file = LoadFromFile._fallback(cls, local_path)
            return loaded_file
        else:
            return LoadFromFile._fallback(cls, path)

    try:
        setattr(patch_fileio, "_patched", True)
        yield
    finally:
        for patched_fn in patch_func._backup:
            (module, fn_name_to_wrap, fn_to_wrap) = patched_fn
            setattr(module, fn_name_to_wrap, fn_to_wrap)
        if global_vars is not None and "open" in global_vars:
            global_vars["open"] = bak_open
        setattr(patch_fileio, "_patched", False)


def patch_hf_auto_from_pretrained(petrel_hub):
    if hasattr(patch_hf_auto_from_pretrained, "_patched"):
        return

    from peft import PeftModel
    from transformers import (
        AutoConfig,
        AutoFeatureExtractor,
        AutoImageProcessor,
        AutoModelForCausalLM,
        AutoProcessor,
        AutoTokenizer,
        ImageProcessingMixin,
        PreTrainedModel,
        PreTrainedTokenizerBase,
        ProcessorMixin,
    )
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    target_cls = list(_BaseAutoModelClass.__subclasses__())
    target_cls.extend([AutoModelForCausalLM] + AutoModelForCausalLM.__subclasses__())
    target_cls.extend([AutoConfig] + AutoConfig.__subclasses__())
    target_cls.extend([AutoTokenizer] + AutoTokenizer.__subclasses__())
    target_cls.extend([AutoImageProcessor] + AutoImageProcessor.__subclasses__())
    target_cls.extend([AutoFeatureExtractor] + AutoFeatureExtractor.__subclasses__())
    target_cls.extend([AutoProcessor] + AutoProcessor.__subclasses__())
    target_cls.extend(
        [PreTrainedTokenizerBase] + PreTrainedTokenizerBase.__subclasses__()
    )
    target_cls.extend([ImageProcessingMixin] + ImageProcessingMixin.__subclasses__())
    target_cls.extend([PreTrainedModel] + PreTrainedModel.__subclasses__())
    target_cls.extend([ProcessorMixin] + ProcessorMixin.__subclasses__())
    target_cls.extend([PeftModel] + PeftModel.__subclasses__())

    import os

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        with patch_fileio():
            model_path = pretrained_model_name_or_path
            model_path = os.path.join(petrel_hub, model_path)
            obj = cls._from_pretrained(model_path, *args, **kwargs)
        return obj

    for cls in set(target_cls):
        if not hasattr(cls, "_from_pretrained"):
            cls._from_pretrained = cls.from_pretrained
            cls.from_pretrained = from_pretrained

    patch_hf_auto_from_pretrained._patched = True


def patch_hf_save_pretrained():
    if hasattr(patch_hf_save_pretrained, "_patched"):
        return

    import torch
    from peft import PeftModel
    from transformers import (
        AutoConfig,
        AutoTokenizer,
        PreTrainedModel,
        PreTrainedTokenizerBase,
    )
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    target_cls = []
    target_cls.extend([AutoConfig] + AutoConfig.__subclasses__())
    target_cls.extend([AutoTokenizer] + AutoTokenizer.__subclasses__())
    target_cls.extend(
        [PreTrainedTokenizerBase] + PreTrainedTokenizerBase.__subclasses__()
    )
    target_cls.extend([PreTrainedModel] + PreTrainedModel.__subclasses__())

    target_cls.extend([_BaseAutoModelClass] + _BaseAutoModelClass.__subclasses__())
    target_cls.extend([PeftModel] + PeftModel.__subclasses__())

    def _patch_wrap(method):
        def wrapped_method(self, *args, **kwargs):
            with patch_fileio():
                kwargs["save_function"] = torch.save
                kwargs["safe_serialization"] = False

                obj = method(self, *args, **kwargs)
            return obj

        return wrapped_method

    for cls in set(target_cls):
        if hasattr(cls, "save_pretrained"):
            cls.save_pretrained = _patch_wrap(cls.save_pretrained)

    patch_hf_save_pretrained._patched = True


def patch_deepspeed_engine():
    """Replace ``DeepSpeedEngine._copy_recovery_script`` (once per process).

    The replacement places ``zero_to_fp32.py`` (taken from
    ``deepspeed.utils.zero_to_fp32``) next to the checkpoint. For a
    ``save_path`` whose backend is Petrel, it uploads the script with
    ``copyfile_from_local``. Otherwise it copies the file locally and calls
    ``self._change_recovery_script_permissions`` on the copy.
    """
    if hasattr(patch_deepspeed_engine, "_patched"):
        return

    def _copy_recovery_script(self, save_path):
        import os
        from shutil import copyfile

        from deepspeed.utils import zero_to_fp32
        from mmengine import PetrelBackend, get_file_backend

        source = zero_to_fp32.__file__
        target = os.path.join(save_path, "zero_to_fp32.py")

        storage = get_file_backend(save_path)
        if isinstance(storage, PetrelBackend):
            storage.copyfile_from_local(source, target)
            return

        copyfile(source, target)
        self._change_recovery_script_permissions(target)

    from deepspeed.runtime.engine import DeepSpeedEngine

    DeepSpeedEngine._copy_recovery_script = _copy_recovery_script

    patch_deepspeed_engine._patched = True
